"""Implement secrets and variable access as used in the infrastructure repo."""

import argparse
import copy
import datetime
import functools
import http
import json
import os
import pathlib
import re
import subprocess
import sys
import tempfile
import typing
from urllib import parse
import uuid

from cki_lib import misc
from cki_lib import yaml
from cki_lib.logger import get_logger
from cki_lib.session import get_session
import requests
import requests_gssapi

# module-wide logger
LOGGER = get_logger('cki_tools.credentials.secrets')
# HTTP session shared by all HashiCorp Vault (HV) requests in this module.
# NOTE(review): raise_for_status=True presumably turns HTTP error responses into
# requests.exceptions.HTTPError, which vault_query()/vault_write() catch - confirm
# against cki_lib.session.get_session.
SESSION = get_session(
    'cki_tools.credentials.secrets',
    raise_for_status=True,
    retry_args={
        'backoff_max': 30,
        # https://developer.hashicorp.com/vault/api-docs#412
        'status_forcelist': [412, 500, 502, 503],
        'total': 120,  # ~30s * 120 = 1h
    },
)
# number of additional attempts vault_query() makes after a 403 response
RETRIES_403 = 3

# token file: read by vault_token() when VAULT_TOKEN is not set, written by
# login() and removed by logout()
VAULT_TOKEN_PATH = pathlib.Path('~/.vault-token').expanduser()


def vault_token() -> str:
    """Look up the HV token.

    A non-empty VAULT_TOKEN environment variable takes precedence; otherwise
    the token is read from VAULT_TOKEN_PATH with surrounding whitespace removed.
    """
    if env_token := os.environ.get('VAULT_TOKEN'):
        return env_token
    return VAULT_TOKEN_PATH.read_text(encoding='utf8').strip()


def vault_query(
    method: str,
    path: str,
    endpoint: str = 'data',
    **kwargs: typing.Any,
) -> requests.Response:
    """Send one request to the HV KV2 API below the cki/ prefix.

    The URL is built from VAULT_ADDR, VAULT_MOUNT_POINT (default: apps), the
    given endpoint and path; any extra keyword arguments are handed to the
    session unchanged. HTTP errors are re-raised, except that a 403 is
    retried up to RETRIES_403 times first.
    """
    vault_addr = os.environ['VAULT_ADDR']
    mount_point = os.environ.get('VAULT_MOUNT_POINT', 'apps')
    url = parse.urljoin(vault_addr, f'v1/{mount_point}/{endpoint}/cki/{path}')
    headers = {'X-Vault-Token': vault_token()}
    # retry a couple of times as HV fails spuriously with 403 even for valid secrets
    retries_left = RETRIES_403
    while True:
        try:
            return SESSION.request(method, url, headers=headers, **kwargs)
        except requests.exceptions.HTTPError as error:
            if error.response.status_code != 403 or retries_left == 0:
                raise
        retries_left -= 1


def vault_read(path: str) -> dict[str, str]:
    """Fetch the secret stored at path in HV and return its fields as a dict."""
    payload = vault_query('GET', path).json()
    # the fields of the secret are nested below data/data in the response
    return typing.cast(dict[str, str], payload['data']['data'])


def vault_write(path: str, value: dict[str, str], *, patch: bool = False) -> None:
    """Store a secret dict at path in HV.

    With patch, the fields currently stored at path (none if the secret does
    not exist yet) are read first and updated with value; without it, value
    replaces the secret as a whole. Failures are reported on stderr before the
    exception is re-raised.

    NOTE(review): the stderr message contains the secret value - presumably so
    a new value is not lost when the write fails; confirm this is intended.
    """
    def report_failure() -> None:
        print(f'Unable to write to HV: {path}={value}', file=sys.stderr)

    if not patch:
        payload = value
    else:
        try:
            payload = vault_read(path)
        except requests.exceptions.HTTPError as error:
            if error.response.status_code != http.HTTPStatus.NOT_FOUND:
                report_failure()
                raise
            # nothing stored yet, start from scratch
            payload = {}
        payload.update(value)
    try:
        vault_query('PUT', path, json={'data': payload})
    except Exception:
        report_failure()
        raise


def vault_list(path: str = '') -> list[str]:
    """Return the names of all secrets in HV below path, descending into folders."""
    keys = vault_query('LIST', path, endpoint='metadata').json()['data']['keys']
    entries = []
    for key in keys:
        full_name = f'{path}{key}'
        # keys with a trailing slash are folders and are listed recursively
        entries.append(vault_list(full_name) if key.endswith('/') else full_name)
    return list(misc.flattened(entries))


def write_secrets_file(data: dict[str, typing.Any]) -> None:
    """Dump data as YAML into the file named by CKI_SECRETS_FILE.

    Raises KeyError if CKI_SECRETS_FILE is unset or empty. The cached file
    contents are dropped afterwards so later reads see the new data.
    """
    secrets_file = os.environ.get('CKI_SECRETS_FILE')
    if not secrets_file:
        raise KeyError('CKI_SECRETS_FILE env variable missing')
    target = pathlib.Path(secrets_file)
    target.write_text(yaml.dump(data), encoding='utf8')
    for cached_reader in (read_secrets_file, _read_secrets_file):
        cached_reader.cache_clear()


@functools.cache
def read_secrets_file() -> dict[str, typing.Any]:
    """Return a cached copy of the raw secrets file data."""
    pristine = _read_secrets_file()
    # prevent accidental modification so edit can check for changes
    return copy.deepcopy(pristine)


@functools.cache
def read_vars_file() -> dict[str, typing.Any]:
    """Load the file named by CKI_VARS_FILE once and return its raw data."""
    variables = _read_yaml_file('CKI_VARS_FILE')
    return variables


@functools.cache
def _read_secrets_file() -> dict[str, typing.Any]:
    """Load the file named by CKI_SECRETS_FILE once and return its raw data."""
    secrets = _read_yaml_file('CKI_SECRETS_FILE')
    return secrets


def _read_yaml_file(env_name: str) -> dict[str, typing.Any]:
    """Return the raw secrets data."""
    if not (file_path := os.environ.get(env_name)):
        raise KeyError(f"Missing required environment variable: {env_name!r}")
    LOGGER.debug('Reading secrets file %s', file_path)
    if not isinstance(secrets := yaml.load(file_path=file_path), dict):
        raise Exception('Invalid secrets file')
    return secrets


def secret(key: str) -> typing.Any:
    """Return a decrypted secret from the secrets file.

    The key is split at the first ':' or, if it contains no ':', at the first
    '#'. The part before the separator is the path of the secret:
    - path          the 'value' field of the secret
    - path:field    the given field of the secret
    - path:         all fields of the secret as a dict
    - path#field    the given meta data field from the secrets file
    - path#         all meta data of the secret as a dict

    If the path ends in a selector like [flag,!flag], a list is returned
    instead, with one result per secret stored at or below the path whose meta
    data matches all conditions: 'flag' requires the meta field to equal True,
    '!flag' requires it to equal False. Meta fields that are not set, and
    secrets without any meta data, count as False.
    """
    all_secrets = read_secrets_file()

    path, sep, field = key.partition(':') if ':' in key else key.partition('#')

    if not (match := re.fullmatch(r'(?P<path>.*)\[(?P<conditions>.*)\]', path)):
        return _single_secret(all_secrets, path, sep, field)

    prefix = match['path']
    # list of (negation marker, flag name) tuples
    conditions = re.findall('(!?)([^,]+)', match['conditions'])

    def matches(data: typing.Any) -> bool:
        # a missing (or empty) meta record must not abort the whole lookup;
        # the record is only touched if there is at least one condition
        return all(
            (data.get('meta') or {}).get(flag, False) == (negation != '!')
            for negation, flag in conditions
        )

    return [
        _single_secret(all_secrets, name, sep, field)
        for name, data in all_secrets.items()
        if (name == prefix or name.startswith(f'{prefix}/')) and matches(data)
    ]


def _single_secret(
    all_secrets: dict[str, typing.Any],
    path: str,
    sep: str,
    field: str,
) -> typing.Any:
    if (value := all_secrets.get(path)) is None:
        raise Exception(f'Secret {path} not found in secrets files')
    if sep == '':
        sep, field = ':', 'value'
    # meta from yaml
    if sep == '#':
        return value['meta'][field] if field else value.get('meta', {})
    match value.get('backend', 'hv'):
        case 'hv':
            # secrets from HashiCorp Vault
            data = vault_read(path)
            return data[field] if field else data
        case _:
            raise Exception(f'Unknown secrets backend for {path}')


def _update_secret_hv(
    data: dict[str, typing.Any],
    path: str,
    sep: str,
    field: str,
    value: typing.Any,
) -> None:
    data['backend'] = 'hv'
    match sep, field:
        case ':', '':
            if isinstance(value, str):
                value = yaml.load(contents=value)
            if value is None:
                LOGGER.warning('No support for deleting HV secrets yet, use the UI')
                data.pop('backend', None)
                return
            vault_write(path, value)
        case ':', _:
            vault_write(path, {field: value}, patch=True)
        case '', _:
            vault_write(path, {'value': value}, patch=True)
        case _:
            raise Exception(f'Unknown secrets separator {sep}')
    meta = data.setdefault('meta', {})
    meta.setdefault('active', True)
    meta.setdefault('created_at', misc.now_tz_utc().isoformat())
    meta.setdefault('deployed', True)


def edit(key: str, value: typing.Any) -> None:
    """Update a secret, or its meta data, and keep the secrets file in sync.

    The key uses the same path/separator/field syntax as secret(): '#' keys
    only touch the meta data in the secrets file, everything else is written
    to the backend of the secret. The secrets file is only rewritten if the
    record of the secret changed. Raises Exception if CKI_SECRETS_FILE is
    unset or empty, or if the backend of the secret is unknown.
    """
    secrets_file = os.environ.get('CKI_SECRETS_FILE')
    if not secrets_file:
        raise Exception('CKI_SECRETS_FILE env variable missing')
    if pathlib.Path(secrets_file).exists():
        all_secrets = copy.deepcopy(_read_secrets_file())
    else:
        LOGGER.warning('Creating new secrets file %s', secrets_file)
        all_secrets = {}
    path, sep, field = key.partition(':') if ':' in key else key.partition('#')
    # work on a copy so the result can be compared with the stored record
    data = copy.deepcopy(all_secrets.get(path, {}))
    backend = data.get('backend', 'hv')
    if sep == '#':
        if field == '':
            # replace (or drop) the meta data as a whole
            if isinstance(value, str):
                value = yaml.load(contents=value)
            if value is None:
                data.pop('meta', None)
            else:
                data['meta'] = value
        else:
            data.setdefault('meta', {})[field] = value
    elif backend == 'hv':
        _update_secret_hv(data, path, sep, field, value)
    else:
        raise Exception(f'Unknown secrets backend for {path}')
    if data:
        if path in all_secrets and all_secrets[path] == data:
            return
        all_secrets[path] = data
    else:
        # an empty record is removed from the file
        if path not in all_secrets:
            return
        del all_secrets[path]
    write_secrets_file(all_secrets)


def variable(key: str) -> typing.Any:
    """Look up an unencrypted variable in the variables file.

    String values are returned with surrounding whitespace removed. Raises
    Exception if the variable is missing or set to None.
    """
    value = read_vars_file().get(key)
    if value is None:
        raise Exception(f'Variable {key} not found in variable files')
    if isinstance(value, str):
        return value.strip()
    return value


def validate() -> tuple[set[str], set[str], set[str]]:
    """Cross-check the secrets file against the secrets stored in the backends.

    Returns a tuple with three sets of secret names:
    - listed with the hv backend in the secrets file, but missing in HashiCorp Vault
    - stored in HashiCorp Vault, but without hv record in the secrets file
    - without the required meta data (created_at) in the secrets file
    """
    hv_names = set(vault_list())
    all_secrets = read_secrets_file()

    meta_names = set()
    for name, data in all_secrets.items():
        # hv is the default backend
        if data.get('backend', 'hv') == 'hv':
            meta_names.add(name)

    incomplete_meta_names = set()
    for name, data in all_secrets.items():
        if not misc.get_nested_key(data, 'meta/created_at'):
            incomplete_meta_names.add(name)

    missing_in_hv = meta_names - hv_names
    missing_in_meta = hv_names - meta_names
    return missing_in_hv, missing_in_meta, incomplete_meta_names


def _login_oidc() -> str:
    """Obtain an HV client token via the OIDC auth method.

    Three requests are made: the auth URL is requested from HV, that URL is
    fetched with SPNEGO authentication without following the redirect, and
    the code and state from the redirect location are handed to the HV
    callback endpoint, which answers with the client token.
    """
    spnego = requests_gssapi.HTTPSPNEGOAuth(mutual_authentication=requests_gssapi.OPTIONAL)
    base_url = parse.urljoin(os.environ['VAULT_ADDR'], 'v1/auth/oidc/oidc/')
    nonce = str(uuid.uuid4())
    start_payload = {
        'client_nonce': nonce,
        'redirect_uri': 'http://localhost:8250/oidc/callback',
    }

    start = SESSION.put(parse.urljoin(base_url, 'auth_url'), json=start_payload)
    auth_url = start.json()['data']['auth_url']

    # only the query parameters of the redirect target are needed
    redirect = SESSION.get(auth_url, auth=spnego, allow_redirects=False)
    location = parse.urlparse(redirect.headers['Location'])
    query = parse.parse_qs(location.query)

    callback_query = parse.urlencode({
        'client_nonce': nonce,
        'code': query['code'][0],
        'state': query['state'][0],
    })
    finish = SESSION.get(parse.urljoin(base_url, 'callback?' + callback_query))
    return typing.cast(str, finish.json()['auth']['client_token'])


def _login_approle(approle: str) -> str:
    """Obtain an HV client token for the given role id.

    The matching secret id is taken from VAULT_APPROLE_SECRET_ID.
    """
    login_url = parse.urljoin(
        parse.urljoin(os.environ['VAULT_ADDR'], 'v1/auth/approle/'),
        'login',
    )
    credentials = {
        'role_id': approle,
        'secret_id': os.environ['VAULT_APPROLE_SECRET_ID'],
    }
    reply = SESSION.put(login_url, json=credentials)
    return typing.cast(str, reply.json()['auth']['client_token'])


def login(oidc: bool, approle: str, duration: datetime.timedelta) -> None:
    """Log into HashiCorp Vault.

    The token is obtained via OIDC if oidc is set and via the given approle
    otherwise. It is stored in VAULT_TOKEN_PATH, which is made accessible to
    the owner only, and then renewed with the requested duration as increment.
    """
    token = _login_oidc() if oidc else _login_approle(approle)
    # the token is a credential: restrict the file to the owner before it is
    # filled, both for a newly created file and for one with wider permissions
    VAULT_TOKEN_PATH.touch(mode=0o600, exist_ok=True)
    VAULT_TOKEN_PATH.chmod(0o600)
    VAULT_TOKEN_PATH.write_text(token, encoding='utf8')

    token_url = parse.urljoin(os.environ['VAULT_ADDR'], 'v1/auth/token/')
    SESSION.put(parse.urljoin(token_url, 'renew-self'), json={
        'increment': int(duration.total_seconds()),
    }, headers={'X-Vault-Token': token})


def logout() -> None:
    """Log out of HashiCorp Vault.

    The current token is revoked via revoke-self and the token file is
    removed if it exists.
    NOTE(review): misc.only_log_exceptions presumably logs and suppresses a
    failing revocation so the token file is removed either way - confirm.
    """
    with misc.only_log_exceptions():
        revoke_url = parse.urljoin(os.environ["VAULT_ADDR"], "v1/auth/token/revoke-self")
        auth_headers = {"X-Vault-Token": vault_token()}
        SESSION.put(revoke_url, headers=auth_headers)
    VAULT_TOKEN_PATH.unlink(missing_ok=True)


def secret_cli(argv: list[str] | None = None) -> int:
    """Run the 'secret' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['secret'] + argv)


def variable_cli(argv: list[str] | None = None) -> int:
    """Run the 'variable' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['variable'] + argv)


def edit_cli(argv: list[str] | None = None) -> int:
    """Run the 'edit' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['edit'] + argv)


def validate_cli(argv: list[str] | None = None) -> int:
    """Run the 'validate' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['validate'] + argv)


def login_cli(argv: list[str] | None = None) -> int:
    """Run the 'login' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['login'] + argv)


def logout_cli(argv: list[str] | None = None) -> int:
    """Run the 'logout' sub-command of main() on argv, by default sys.argv[1:]."""
    if argv is None:
        argv = sys.argv[1:]
    return main(['logout'] + argv)


def _print(result: typing.Any, format_json: bool) -> None:
    if format_json:
        print(json.dumps(result))
    else:
        for single_result in misc.flattened(result):
            print(single_result)


def _read_interactive(key: str) -> str:
    """Start EDITOR/vi to edit the secret.

    The editor is opened on a temporary file that contains the current value
    for key; the contents of that file after the editor exits are returned
    with surrounding whitespace removed. If the current value cannot be
    determined, the file starts out empty. Dicts and lists (as returned for
    whole secrets and whole meta data) are shown as YAML, other non-string
    values via str().
    """
    try:
        old_value = secret(key)
    except Exception:  # pylint: disable=broad-except
        # secret() raises a plain Exception for unknown secrets, and a new
        # secret should simply start with an empty editor
        old_value = ''
    if isinstance(old_value, (dict, list)):
        # write_text() only accepts strings; edit() parses whole secrets and
        # whole meta data as YAML again
        old_value = yaml.dump(old_value)
    elif not isinstance(old_value, str):
        old_value = str(old_value)
    with tempfile.NamedTemporaryFile() as tmp:
        tmp_path = pathlib.Path(tmp.name)
        tmp_path.write_text(old_value, encoding='utf8')
        subprocess.run([os.environ.get('EDITOR', 'vi'), tmp.name], check=True)
        return tmp_path.read_text(encoding='utf8').strip()


def main(argv: list[str] | None = None) -> int:
    # pylint: disable=too-many-return-statements
    """Access the CKI secrets tools."""
    parser = argparse.ArgumentParser(description='Access CKI variables and secrets')

    subparsers = parser.add_subparsers(dest='type', required=True)

    parser_secret = subparsers.add_parser(
        'secret', help='Retrieve a secret value', description='Retrieve a secret value')
    parser_secret.add_argument('key', help='secret name')
    parser_secret.add_argument('--json', action='store_true', help='output in json format')

    parser_variable = subparsers.add_parser(
        'variable', help='Retrieve a variable value', description='Retrieve a variable value')
    parser_variable.add_argument('key', help='variable name')
    parser_variable.add_argument('--json', action='store_true', help='output in json format')

    parser_edit = subparsers.add_parser(
        'edit', help='Edit a secret value', description='Edit a secret value')
    parser_edit.add_argument('key', help='secret name')
    parser_edit.add_argument('value', nargs='?', help='new secret value, opens editor if missing')

    subparsers.add_parser(
        'validate', help='Validate stored secrets', description='Validate stored secrets')

    parser_login = subparsers.add_parser(
        'login', help='Log into secrets storage', description='Log into secrets storage')
    parser_login.add_argument('--duration', default='10h', help='validity of the token')
    login_group = parser_login.add_mutually_exclusive_group(required=True)
    login_group.add_argument('--oidc', action='store_true', help='login via OIDC')
    login_group.add_argument('--approle', help='login via given approle')

    subparsers.add_parser(
        'logout', help='Log out of secrets storage', description='Log out of secrets storage')

    args = parser.parse_args(argv)

    match args.type:
        case 'secret':
            _print(secret(args.key), args.json)
            return 0
        case 'variable':
            _print(variable(args.key), args.json)
            return 0
        case 'edit':
            value = _read_interactive(args.key) if args.value is None else args.value
            edit(args.key, value)
            return 0
        case 'validate':
            missing_hv, missing_meta, missing_required_meta = validate()
            for name in missing_hv:
                print(f'Missing in HashiCorp Vault: {name}')
            for name in missing_meta:
                print(f'Missing in secrets meta data: {name}')
            for name in missing_required_meta:
                print(f'Missing required meta data: {name}')
            return 2 if missing_hv or missing_required_meta else 1 if missing_meta else 0
        case 'login':
            login(args.oidc, args.approle, misc.parse_timedelta(args.duration))
            return 0
        case 'logout':
            logout()
            return 0
        case _:
            return 2


# when run as a script, the return value of main() becomes the process exit status
if __name__ == '__main__':
    sys.exit(main())
